import os
import pickle
import re
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.linear_model import LinearRegression

# Global plotting theme: white grid, Times New Roman text at 1.3x scale,
# STIX glyphs for mathtext, and a fixed five-colour palette (in cycle order).
custom_palette = [
    "#033473",  # blue
    "#207f4c",  # green
    "#a60303",  # red
    "#61649f",  # purple
    "#f2ae30",  # orange
]
sns.set_theme(
    style="whitegrid",
    palette=custom_palette,
    font="Times New Roman",
    font_scale=1.3,
    rc={"mathtext.fontset": "stix"},
)


# ================================================================== #
# 🚀 sample
# ================================================================== #

def load_sample_results(save_dir: str, divergences=("kl", "f")) -> dict:
    """Load pickled sample-size experiment results.

    Scans ``<save_dir>/<divergence>/`` for every divergence in
    ``divergences``. Only files whose names start with the pattern
    ``result_n=<n>_run=<run>.pkl`` are unpickled; all other files are ignored.

    Args:
        save_dir: Root directory holding one sub-directory per divergence.
        divergences: Sub-directory names to scan, in order.

    Returns:
        Dict mapping ``(divergence, n)`` to the unpickled result. ``n`` is the
        raw string captured from the file name (e.g. ``"None"`` or
        ``"1000"``). The run index is NOT part of the key, so when several
        runs exist for the same ``n`` only the last one in sorted filename
        order is kept.

    Note:
        Uses :mod:`pickle`; only point this at result files you produced.
    """
    pattern = re.compile(r"result_n=([^_]+)_run=([^_]+)\.pkl")
    results = {}
    for divergence in divergences:
        div_dir = os.path.join(save_dir, divergence)
        # sorted() makes the "last run wins" overwrite below deterministic;
        # os.listdir returns entries in arbitrary order.
        for file_name in sorted(os.listdir(div_dir)):
            match = pattern.match(file_name)
            if match is None:
                continue
            n = match.group(1)
            with open(os.path.join(div_dir, file_name), "rb") as f:
                results[(divergence, n)] = pickle.load(f)
    return results


def compute_norm(results: dict) -> pd.DataFrame:
    """Compute the sup-norm error of each empirical value vector.

    Args:
        results: Mapping ``(divergence, n_str) -> result`` where each result
            has keys ``"n"`` and ``"v"``. Every divergence that has an
            empirical entry (``result["n"] is not None``) must also have an
            entry keyed ``(divergence, "None")`` holding the population
            vector ``v``.

    Returns:
        DataFrame with columns ``divergence``, ``n`` and ``error``, one row per
        empirical entry. ``error`` is ``||v - v_population||_inf``. The
        columns are present even when there are no empirical entries.

    Raises:
        KeyError: if an empirical entry's divergence has no population entry.
    """
    rows = []
    for key, result in results.items():
        n = result["n"]
        if n is None:
            continue  # the population entry is the reference, not a sample
        divergence = key[0]
        v_population = results[(divergence, "None")]["v"]
        error = np.linalg.norm(result["v"] - v_population, ord=np.inf)
        rows.append({"divergence": divergence, "n": n, "error": error})
    # Explicit columns so an empty result still exposes "n"/"error" downstream.
    return pd.DataFrame(rows, columns=["divergence", "n", "error"])


def compute_lr_coef(df: pd.DataFrame):
    """Fit ``log(error) = intercept + slope * log(n)`` by least squares.

    Args:
        df: Frame with numeric columns ``n`` and ``error``; both should be
            positive so their logarithms are finite. The frame is modified
            in place: a ``prediction`` column is added.

    Returns:
        ``(df, intercept, coef)``: the same (mutated) frame, the fitted
        intercept in log space, and the coefficient array (``coef[0]`` is the
        log-log slope).
    """
    log_n = np.log(df["n"]).to_numpy()[:, np.newaxis]
    log_err = np.log(df["error"]).to_numpy()
    model = LinearRegression().fit(log_n, log_err)
    # Map the fitted log-error back to the original scale for plotting.
    df["prediction"] = np.exp(model.predict(log_n))
    return df, model.intercept_, model.coef_


def plot_sample_n(save_dir, n_multiplier=16 * 6):
    """Plot estimation error versus sample size for the KL and chi^2 runs.

    Loads the results under ``save_dir``, computes each run's sup-norm error
    against the population solution, and rescales ``n`` by ``n_multiplier``.
    It then fits a power law in log-log space and saves one figure per
    divergence to ``imgs/n_kl.png`` and ``imgs/n_f.png``.

    Args:
        save_dir: Root directory with ``kl/`` and ``f/`` result sub-folders.
        n_multiplier: Factor applied to each run's ``n`` before fitting and
            plotting. The original hard-coded ``16 * 6``; presumably it converts
            a per-entry sample count to a total count (TODO confirm).
    """
    results = load_sample_results(save_dir)
    results_kl = {key: value for key, value in results.items() if key[0] == "kl"}
    results_f = {key: value for key, value in results.items() if key[0] == "f"}
    df_kl = compute_norm(results_kl)
    df_f = compute_norm(results_f)
    # Each frame is scaled from its OWN n column (the original used df_f's n
    # for df_kl, which is wrong and also aligns on the index).
    df_kl["n"] = df_kl["n"] * n_multiplier
    df_f["n"] = df_f["n"] * n_multiplier
    df_kl, _, coef_kl = compute_lr_coef(df_kl)
    df_f, _, coef_f = compute_lr_coef(df_f)

    _plot_loglog_fit(df_kl, rf"KL, slope = ${coef_kl[0]:.2f}$", "imgs/n_kl.png")
    _plot_loglog_fit(df_f, rf"$\chi^2$, slope = ${coef_f[0]:.2f}$", "imgs/n_f.png")


def _plot_loglog_fit(df, label, out_path):
    """Scatter measured errors and overlay the fitted line on log-log axes, then save and show."""
    plt.figure(figsize=(6, 4))
    # Measured points use "error"; the original chi^2 plot mistakenly
    # scattered "prediction" here.
    sns.scatterplot(data=df, x="n", y="error")
    sns.lineplot(data=df, x="n", y="prediction", label=label)
    plt.yscale("log")
    plt.xscale("log")
    plt.xlabel(r"$n$")
    plt.ylabel(r"$\epsilon$")
    plt.tight_layout()
    out_dir = os.path.dirname(out_path)
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)  # savefig does not create folders
    plt.savefig(out_path, dpi=1200)
    plt.show()


# ================================================================== #
# 🚀 general mdp
# ================================================================== #
def load_results_general_mdp(root="results/general_mdp", divergences=("kl", "f")):
    """Load pickled general-MDP sweep results for each divergence.

    Every file in ``<root>/<divergence>/`` is unpickled and expected to be a
    dict whose keys are tuples starting with ``(S, A, n, run)``. Files are
    read in sorted name order, so collisions resolve deterministically: a
    later file overwrites an earlier one with the same key.

    Args:
        root: Directory containing one sub-directory per divergence.
        divergences: Sub-directory names to read, in order.

    Returns:
        Dict keyed by ``(S, A, n, run, divergence)``. This is the 5-tuple that
        ``plot_general_mdp`` unpacks. Only the first four elements of each
        stored key are kept; the original did this for "f" but not for "kl".

    Note:
        Uses :mod:`pickle`; only load result files you produced yourself.
    """
    results = {}
    for divergence in divergences:
        div_dir = os.path.join(root, divergence)
        for result_name in sorted(os.listdir(div_dir)):
            with open(os.path.join(div_dir, result_name), "rb") as f:
                result = pickle.load(f)
            for k, v in result.items():
                results[(*k[:4], divergence)] = v
    return results


def plot_general_mdp_(df, x, divergence, divergence_name, xlim, ylim,
                      fixed_size=65, run=0):
    """Plot normalised error against one MDP dimension, with a linear fit per n.

    Args:
        df: Frame with columns ``S``, ``A``, ``n``, ``run``, ``divergence``
            and ``error``. As built by ``plot_general_mdp``, ``error`` is
            already divided by ``log(S * A)``.
        x: Dimension on the horizontal axis, ``"S"`` or ``"A"``. The other
            dimension is held at ``fixed_size``.
        divergence: Value of the ``divergence`` column to plot.
        divergence_name: Display name used in legend labels.
        xlim, ylim: Axis limits passed to ``plt.xlim`` / ``plt.ylim``.
        fixed_size: Value of the non-swept dimension to select (was 65).
        run: Run index to select (was 0).

    Saves the figure to ``imgs/plot_<x>_<divergence>.png`` and shows it.
    """
    plt.figure(figsize=(6, 4))
    held = "S" if x == "A" else "A"
    df_sub = df[(df[held] == fixed_size)
                & (df["divergence"] == divergence)
                & (df["run"] == run)]
    df_group = df_sub.groupby(["S", "A", "n", "divergence"])["error"].mean().reset_index()
    for n in df_group["n"].unique():
        df_n = df_group[df_group["n"] == n]
        plt.scatter(df_n[x], df_n["error"], alpha=0.7)
        # Ordinary least-squares fit of error against the swept dimension x.
        lin_reg = LinearRegression()
        lin_reg.fit(df_n[x].values.reshape(-1, 1), df_n["error"].values)
        x_range = np.linspace(df_n[x].min(), df_n[x].max(), 100)
        plt.plot(x_range, lin_reg.intercept_ + lin_reg.coef_[0] * x_range,
                 label=fr"{divergence_name}, $n_0$={n}")
    plt.xlabel(rf"$|\mathcal{{{x}}}|$")
    plt.ylabel(r"$\epsilon/\log(|\mathcal{S}||\mathcal{A}|)$")
    plt.legend(labelspacing=0.0, loc='upper right')
    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.tight_layout()
    os.makedirs("imgs", exist_ok=True)  # savefig does not create folders
    plt.savefig(rf"imgs/plot_{x}_{divergence}.png", dpi=1200)
    plt.show()


def plot_general_mdp(n_values=(1000, 2000, 5000)):
    """Build the normalised-error table for the general-MDP sweep and draw it.

    For every empirical result ``(S, A, n, run, divergence)`` the error is
    ``||v_emp - v_pop||_inf / log(S * A)``. Here ``v_pop`` is the run-0
    population result (``n is None``) for the same ``S``, ``A`` and
    divergence. Rows are then restricted to ``n`` in ``n_values``. The
    function draws four figures: error vs |S| and error vs |A|, for the KL
    and chi^2 divergences.

    Args:
        n_values: Sample sizes to keep in the plots.

    Raises:
        KeyError: if an empirical result has no matching population entry.
    """
    results = load_results_general_mdp()
    rows = []
    for (S, A, n, run, divergence), result in results.items():
        if n is None:
            continue  # population solution: the reference, not a sample
        v_population = results[(S, A, None, 0, divergence)]["v"]
        v_empirical = result["v"]
        error = np.linalg.norm(v_empirical - v_population, ord=np.inf) / np.log(S * A)
        rows.append({"S": S, "A": A, "n": n, "run": run,
                     "divergence": divergence, "error": error})
    # Explicit columns keep the filter below valid even with no results.
    df = pd.DataFrame(rows, columns=["S", "A", "n", "run", "divergence", "error"])
    df = df[df["n"].isin(n_values)]

    # Error vs |S| (|A| fixed), then error vs |A| (|S| fixed); KL then chi^2.
    for x in ("S", "A"):
        plot_general_mdp_(df, x, "kl", "KL", xlim=(0, None), ylim=(0.008, 0.05))
        plot_general_mdp_(df, x, "f", r"$\chi^2$", xlim=(0, None), ylim=(0.015, 0.1))
    # f divergence


# Script entry point: regenerate every figure. Guarded so that importing this
# module (e.g. to reuse the loaders) does not trigger loading and plotting.
if __name__ == "__main__":
    plot_sample_n('results/inventory/sample')
    plot_general_mdp()
